"""
This file contains methods for generating calibration-related plots, e.g. reliability plots.
References:
[1] C. Guo, G. Pleiss, Y. Sun, and K. Q. Weinberger. On calibration of modern neural networks.
    arXiv preprint arXiv:1706.04599, 2017.
"""
import imp
import os
import math
import matplotlib.pyplot as plt
import wandb
import numpy as np
from numpy import cov, expand_dims, linalg, atleast_2d
from sklearn.metrics import auc
from scipy import stats
import torch
import torch.nn.functional as F
from loss_functions.auc_loss_bw import (
    get_thresholds_from_cdf,
    get_thresholds_from_cdf_np,
)

# Import-time side effect: raise matplotlib's default font size to 20.
plt.rcParams.update({"font.size": 20})

# Keys of the per-bin statistics dictionaries built by `_populate_bins`.
COUNT = "count"  # number of samples assigned to the bin
CONF = "conf"  # running sum of the confidences of those samples
ACC = "acc"  # running count of samples whose prediction equals the label
BIN_ACC = "bin_acc"  # ACC / COUNT for the bin (0 when the bin is empty)
BIN_CONF = "bin_conf"  # CONF / COUNT for the bin (0 when the bin is empty)


def _bin_initializer(bin_dict, num_bins=10):
    """
    Zero every statistic of the bins ``0 .. num_bins - 1`` in place.

    ``bin_dict[i]`` must already exist (as a dict) for each of those indices;
    any previous values stored under the statistic keys are overwritten.
    """
    stat_keys = (COUNT, CONF, ACC, BIN_ACC, BIN_CONF)
    for bin_index in range(num_bins):
        bin_stats = bin_dict[bin_index]
        for key in stat_keys:
            bin_stats[key] = 0


def _populate_bins(confs, preds, labels, num_bins=10):
    """
    Group samples into equal-width confidence bins and compute per-bin statistics.

    A sample with confidence ``c`` is assigned to bin ``ceil(num_bins * c - 1)``,
    i.e. bin ``i`` collects confidences in ``(i / num_bins, (i + 1) / num_bins]``.
    The index is clamped to ``[0, num_bins - 1]`` so that a confidence of exactly
    0 (which would map to -1) or one marginally above 1 due to floating-point
    rounding (which would map to ``num_bins``) no longer raises ``KeyError``.

    Args:
        confs: indexable sequence of per-sample confidences.
        preds: indexable sequence of per-sample predictions.
        labels: indexable sequence of per-sample ground-truth labels.
        num_bins: number of bins to create.

    Returns:
        dict mapping each bin index to a dict with the keys COUNT, CONF, ACC,
        BIN_ACC and BIN_CONF. BIN_ACC and BIN_CONF are 0 for empty bins.
    """
    bin_dict = {i: {} for i in range(num_bins)}
    _bin_initializer(bin_dict, num_bins)
    last_bin = num_bins - 1

    for i in range(len(confs)):
        confidence = confs[i]
        prediction = preds[i]
        label = labels[i]
        raw_bin = int(math.ceil((num_bins * confidence) - 1))
        # Keep boundary confidences inside the valid bin range.
        binn = min(max(raw_bin, 0), last_bin)
        stats_for_bin = bin_dict[binn]
        stats_for_bin[COUNT] = stats_for_bin[COUNT] + 1
        stats_for_bin[CONF] = stats_for_bin[CONF] + confidence
        stats_for_bin[ACC] = stats_for_bin[ACC] + (1 if (label == prediction) else 0)

    for binn in range(num_bins):
        stats_for_bin = bin_dict[binn]
        count = stats_for_bin[COUNT]
        if count == 0:
            # Empty bin: report zeros instead of dividing by zero.
            stats_for_bin[BIN_ACC] = 0
            stats_for_bin[BIN_CONF] = 0
        else:
            stats_for_bin[BIN_ACC] = float(stats_for_bin[ACC]) / count
            stats_for_bin[BIN_CONF] = stats_for_bin[CONF] / float(count)
    return bin_dict


def reliability_plot(confs, preds, labels, plot_name, num_bins=15):
    """
    Draw a reliability plot from a model's predictions and confidences.

    Samples are grouped into ``num_bins`` confidence bins by ``_populate_bins``.
    For each bin, the observed accuracy ("Actual") is drawn on top of a bar whose
    height equals the bin's lower confidence edge ("Expected"). The figure is
    saved to ``plot_name + ".png"`` and left open as the current pyplot figure.

    Args:
        confs: per-sample confidences.
        preds: per-sample predictions.
        labels: per-sample ground-truth labels.
        plot_name: output path without the ".png" extension.
        num_bins: number of equal-width confidence bins.
    """
    bin_dict = _populate_bins(confs, preds, labels, num_bins)
    bns = [(i / float(num_bins)) for i in range(num_bins)]
    y = [bin_dict[i][BIN_ACC] for i in range(num_bins)]
    # Bars must span exactly one bin; a fixed 0.05 only matches 20 bins and
    # leaves gaps (or overlaps) for any other ``num_bins``.
    bar_width = 1.0 / num_bins
    plt.figure(figsize=(10, 8))  # width: 10, height: 8 (inches)
    plt.bar(bns, bns, align="edge", width=bar_width, color="pink", label="Expected")
    plt.bar(
        bns, y, align="edge", width=bar_width, color="blue", alpha=0.5, label="Actual"
    )
    plt.ylabel("Accuracy")
    plt.xlabel("Confidence")
    plt.legend()
    plt.savefig(plot_name + ".png")


def bin_strength_plot(confs, preds, labels, plot_name, num_bins=15):
    """
    Draw a bar plot of the percentage of samples falling in each confidence bin.

    Samples are grouped into ``num_bins`` confidence bins by ``_populate_bins``;
    each bar shows ``100 * count / len(labels)`` for one bin. The figure is saved
    to ``plot_name + ".png"`` and left open as the current pyplot figure.

    Args:
        confs: per-sample confidences.
        preds: per-sample predictions.
        labels: per-sample ground-truth labels; its length is the denominator
            of the percentages.
        plot_name: output path without the ".png" extension.
        num_bins: number of equal-width confidence bins.
    """
    bin_dict = _populate_bins(confs, preds, labels, num_bins)
    bns = [(i / float(num_bins)) for i in range(num_bins)]
    num_samples = len(labels)
    y = [
        (bin_dict[i][COUNT] / float(num_samples)) * 100 for i in range(num_bins)
    ]
    # Bars must span exactly one bin; a fixed 0.05 only matches 20 bins.
    bar_width = 1.0 / num_bins
    plt.figure(figsize=(10, 8))  # width: 10, height: 8 (inches)
    plt.bar(
        bns,
        y,
        align="edge",
        width=bar_width,
        color="blue",
        alpha=0.5,
        label="Percentage samples",
    )
    plt.ylabel("Percentage of samples")
    plt.xlabel("Confidence")
    plt.savefig(plot_name + ".png")


def roc_no_decision(confidences, labels, settings, plot_name):
    """
    Plot "1 - error" on the decided samples against the fraction left undecided.

    For every confidence threshold returned by ``get_thresholds_from_cdf_np``,
    a sample counts as "decided" when its top confidence is >= the threshold.
    The curve, annotated with its area (``sklearn.metrics.auc``), is saved to
    ``plot_name + ".png"``.

    Args:
        confidences: 2-D array; the maximum / argmax along axis 1 give each
            sample's top confidence and prediction.
        labels: NumPy array of ground-truth labels (its ``size`` is the dataset size).
        settings: object whose ``plot_together`` flag is read; when it equals 1
            the two curves are also written into column ``settings.count`` of
            ``settings.np_all_one_minus_error`` and ``settings.np_all_undecided``.
        plot_name: output path without the ".png" extension.
    """
    n_total = labels.size
    top_conf = np.amax(confidences, axis=1)
    thresholds = get_thresholds_from_cdf_np(top_conf)
    predicted = np.argmax(confidences, axis=1)

    undecided_fraction = np.zeros_like(thresholds)
    accuracy_on_decided = np.zeros_like(thresholds)

    # The misclassification mask does not depend on the threshold.
    misclassified = np.not_equal(labels, predicted)

    for idx in range(thresholds.size):
        threshold = thresholds[idx]
        decided = np.asarray(top_conf >= threshold)
        n_wrong = np.sum(np.logical_and(decided, misclassified))
        # The small epsilon avoids dividing by zero when nothing is decided.
        n_decided = np.sum(decided) + 1e-10
        accuracy_on_decided[idx] = 1.0 - (n_wrong / n_decided).item()

        n_undecided = np.sum(np.asarray(top_conf < threshold))
        undecided_fraction[idx] = (n_undecided / n_total).item()

    plt.figure(figsize=(10, 8))
    area = auc(undecided_fraction, accuracy_on_decided)
    plt.plot(undecided_fraction, accuracy_on_decided, label="AUC = %0.3f" % area)
    plt.ylabel("1 - error (%)")
    plt.xlabel("Samples without decision (%)")
    plt.ylim([0.9, 1.00])
    plt.yticks(np.arange(0.9, 1.0, step=0.05))
    plt.grid(ls="--", lw=0.5, markevery=0.05)
    plt.legend(loc="lower right")
    plt.savefig(plot_name + ".png")

    if settings.plot_together == 1:
        # Keep this model's curves so they can be drawn with the others later.
        settings.np_all_one_minus_error[:, settings.count] = accuracy_on_decided
        settings.np_all_undecided[:, settings.count] = undecided_fraction
    # np.savetxt(
    #     settings.plots_dir + "/values_{}".format(settings.model_name) + ".csv",
    #     np.concatenate(
    #         (
    #             np.expand_dims(thresholds_confidence, axis=1),
    #             np.expand_dims(samples_without_decision, axis=1),
    #             np.expand_dims(one_minus_error, axis=1),
    #         ),
    #         axis=1,
    #     ),
    #     fmt="%f",
    #     newline="\n",
    #     delimiter=",",
    #         header="thresholds_confidence, samples_without_decision, one_minus_error",
    # )


def plot_roc_together(settings, checkpoint_file):
    """
    Plot the "1 - error" vs. "samples without decision" curves of several models together.

    One curve is drawn per entry of ``settings.loss_configs_array``, using
    column ``i`` of ``settings.np_all_undecided`` and
    ``settings.np_all_one_minus_error`` and labelled with
    ``settings.loss_type_array[i]`` and the curve's area. The figure is saved under
    ``/root/<project_name>/<dataset>/<net_type>/compare_roc<suffix>_FL_TS.png``.

    Args:
        settings: object providing ``project_name``, ``dataset``, ``net_type``,
            ``loss_configs_array``, ``loss_type_array``, ``np_all_undecided``
            and ``np_all_one_minus_error``.
        checkpoint_file: checkpoint name; "best_acc" / "best_auc" in it selects
            the file-name suffix. Any other name yields no suffix (previously
            this raised ``UnboundLocalError``).
    """
    list_colors = ["b", "g", "r", "c", "y", "m", "k", "orange", "purple"]
    if "best_acc" in checkpoint_file:
        suffix = "_best_acc"
    elif "best_auc" in checkpoint_file:
        suffix = "_best_auc"
    else:
        # Unknown checkpoint kind: still produce a plot, just without a suffix.
        suffix = ""

    plot_name = os.path.join(
        "/root",
        settings.project_name,
        settings.dataset,
        settings.net_type,
        "compare_roc" + suffix + "_FL_TS.png",
    )

    plt.figure(figsize=(10, 8))
    for i in range(len(settings.loss_configs_array)):
        # Cycle through the palette so more than len(list_colors) curves
        # do not raise IndexError.
        color = list_colors[i % len(list_colors)]
        samples_without_decision = settings.np_all_undecided[:, i]
        one_minus_error = settings.np_all_one_minus_error[:, i]
        auc_plot = auc(samples_without_decision, one_minus_error) * 100.0
        plt.plot(
            samples_without_decision * 100.00,
            one_minus_error * 100.00,
            color=color,
            label="{}, AUCOC {:.2f}".format(
                settings.loss_type_array[i],
                auc_plot,
            ),
        )
        if samples_without_decision[-1] < 0.99000:
            # The curve stops before 99% undecided: mark its end point with a
            # dashed vertical line down to the x axis.
            plt.plot(
                [
                    samples_without_decision[-1] * 100.00,
                    samples_without_decision[-1] * 100.00,
                ],
                [one_minus_error[-1] * 100.00, 0],
                color=color,
                linestyle="dashed",
                alpha=0.7,
            )
    plt.ylabel("E" + r"$[c|r>r_0]$" + ": accuracy of the network (%)")
    plt.xlabel(r"$\tau_0$" + ": samples to be analysed manually (%)")
    plt.ylim([70.00, 100.00])
    plt.xlim([0.0, 100.00])
    plt.yticks(np.arange(70.00, 100.00, step=10.00))
    plt.grid(ls="--", lw=0.5, markevery=10.00)
    plt.legend(loc="lower right", fontsize="small")

    plt.savefig(plot_name)


def get_cdf(settings, confidences):
    """
    Plot the empirical CDF of the per-sample top confidence and export the values.

    The maximum of ``confidences`` along axis 1 is taken for every sample and
    sorted. The sorted values are plotted against ``k / N`` and saved as
    ``cdf.png`` in ``settings.plots_dir``; the same sorted values are written to
    ``confidences.csv`` in that directory. Nothing is returned.

    NOTE(review): no new figure is created here, so the curve is drawn on
    whatever pyplot figure is current when the function is called.
    """
    cdf_png_path = os.path.join(
        settings.plots_dir,
        "cdf.png",
    )

    sorted_top_conf = np.sort(np.amax(confidences, axis=1))
    n_samples = sorted_top_conf.size
    # Empirical CDF value associated with each sorted confidence.
    cdf_values = np.arange(n_samples) / float(n_samples)

    plt.xlabel("values")
    plt.ylabel("CDF")
    plt.plot(sorted_top_conf, cdf_values)
    plt.xlim([0.00, 1.00])
    plt.savefig(cdf_png_path)

    csv_path = settings.plots_dir + "/confidences.csv"
    np.savetxt(
        csv_path,
        sorted_top_conf,
        fmt="%f",
        newline="\n",
        delimiter=",",
        header="confidences",
    )


# def get_thresholds_from_cdf(settings, confidences):
#     """
#     Given a distribution of confidences, compute the thresholds to generate the ROC plots.
#     """
#     thresholds = np.zeros(settings.num_thresholds)
#     top_confidences = np.amax(confidences, axis=1)
#     confidencs_sorted = np.sort(top_confidences)
#     # get the cdf confidencs_sorted of y
#     cdf = np.arange(confidencs_sorted.size) / float(confidencs_sorted.size)
#     if settings.num_thresholds > confidencs_sorted.size:
#         samples_uniform_cdf = np.linspace(0, 1, settings.num_thresholds)
#         thresholds = np.interp(samples_uniform_cdf, cdf, confidencs_sorted)
#     else:
#         index_samples_uniform_cdf = np.arange(
#             0,
#             top_confidences.size,
#             step=int(top_confidences.size / settings.num_thresholds),
#         )

#         samples_uniform_cdf = cdf[index_samples_uniform_cdf]
#         j = 0
#         for i in samples_uniform_cdf:
#             thresholds[j] = confidencs_sorted[np.where(cdf == i)[0]]
#             j += 1
#     return thresholds


def my_kde(r, rn, alpha):
    """
    Evaluate a 1-D Gaussian kernel density estimate built from ``rn`` at ``r``.

    The kernel covariance is the unbiased sample variance of ``rn`` scaled by
    ``alpha ** 2``. Both the samples and the evaluation points are whitened with
    the Cholesky factor of the corresponding precision, after which the
    estimate is the mean of standard normal kernels times the whitening scale.

    Args:
        r: 1-D array of points at which the density is evaluated.
        rn: 1-D array of samples defining the estimate.
        alpha: bandwidth factor multiplying the sample standard deviation.

    Returns:
        ``np.squeeze`` of the ``(len(r), 1)`` array of density values, i.e. a
        1-D array of length ``len(r)`` (0-d when ``r`` has a single element).
    """
    precision = np.linalg.inv(np.atleast_2d(np.cov(rn, rowvar=1, bias=False))) / alpha ** 2
    dtype = np.common_type(precision, rn)
    whitening = np.linalg.cholesky(precision).astype(dtype, copy=False)
    points = np.dot(np.expand_dims(rn, 1), whitening).astype(dtype, copy=False)
    xi = np.dot(np.expand_dims(r, 1), whitening).astype(dtype, copy=False)

    # Normalisation of a 1-D Gaussian expressed in whitened coordinates,
    # combined with the 1 / n_samples averaging weight.
    norm = math.pow(2 * math.pi, -0.5) * whitening[0, 0]
    weight = norm / points.shape[0]

    estimate = np.zeros((xi.shape[0], 1), dtype)
    # Loop over evaluation points only and reduce over all samples in NumPy.
    # This removes the per-pair Python loop while keeping memory at
    # O(len(rn)) instead of materialising a len(r) x len(rn) matrix.
    for j, point in enumerate(xi):
        sq_dist = np.sum((points - point) ** 2, axis=1)
        estimate[j, 0] = np.sum(np.exp(-sq_dist / 2)) * weight

    return np.squeeze(estimate)


def hist_vs_kde(train_outputs):
    """
    Plot a histogram of the top softmax confidences together with their Gaussian KDE.

    ``train_outputs`` is passed through a softmax along dim 1 and the maximum
    along that dimension is taken per sample. A ``scipy.stats.gaussian_kde``
    fitted on those values is evaluated on 200 points in [-0.1, 1.1] and drawn
    over a 20-bin density histogram and a scatter of the individual samples.
    The result is saved as ``kde_hist_new.png`` in the working directory.

    Computations whose results were never used (a second KDE evaluation via
    ``my_kde``, the CDF thresholds and a zero array derived from them) have
    been removed; they did not contribute to the figure.

    NOTE(review): no new figure is created here, so everything is drawn on
    whatever pyplot figure is current when the function is called.

    Args:
        train_outputs: torch tensor of model outputs; presumably logits of
            shape (num_samples, num_classes) - confirm against the caller.
    """
    confidences = F.softmax(train_outputs, dim=1)
    r_torch, _ = torch.max(confidences, 1)
    r = r_torch.cpu().detach().numpy()
    pdf = stats.gaussian_kde(r, bw_method=None, weights=None)

    x = np.linspace(-0.1, 1.1, 200)
    y = pdf(x)
    r = np.sort(r)

    plt.hist(
        r,
        20,
        (-0.1, 1.1),
        histtype="stepfilled",
        density=True,
        alpha=0.2,
        color="k",
        label="histogram",
    )
    plt.plot(x, y, label="kde")
    plt.scatter(r, np.zeros_like(r), marker="x", color="k", alpha=0.1, label="samples")
    plt.savefig("kde_hist_new.png")
    print("plot done")
